{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Auxiliary Classifier Generative Adversarial Networks (AC-GAN) on MNIST\n",
    "\n",
    "Modified from Keras examples:\n",
    "\n",
    "https://github.com/keras-team/keras/blob/master/examples/mnist_acgan.py\n",
    "\n",
    "Original repo: https://github.com/lukedeo/keras-acgan\n",
    "\n",
    "Here, we use the tensorflow backend. The learning rate is decreased."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "KERAS_MODEL_FILEPATH = '../../demos/data/mnist_acgan/mnist_acgan.h5'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Using TensorFlow backend.\n",
      "/home/leon/miniconda3/lib/python3.6/importlib/_bootstrap.py:219: RuntimeWarning: compiletime version 3.5 of module 'tensorflow.python.framework.fast_tensor_util' does not match runtime version 3.6\n",
      "  return f(*args, **kwds)\n"
     ]
    }
   ],
   "source": [
    "from PIL import Image\n",
    "from keras.datasets import mnist\n",
    "from keras.layers import Input, Dense, Reshape, Flatten, Embedding, Dropout\n",
    "from keras.layers import Multiply, LeakyReLU, UpSampling2D, Conv2D\n",
    "from keras.models import Sequential, Model\n",
    "from keras.optimizers import Adam\n",
    "import numpy as np\n",
    "\n",
    "np.random.seed(1337)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "def build_generator(latent_size):\n",
    "    # we will map a pair of (z, L), where z is a latent vector and L is a\n",
    "    # label drawn from P_c, to image space (..., 28, 28, 1)\n",
    "    cnn = Sequential()\n",
    "\n",
    "    cnn.add(Dense(1024, input_dim=latent_size, activation='relu'))\n",
    "    cnn.add(Dense(128 * 7 * 7, activation='relu'))\n",
    "    cnn.add(Reshape((7, 7, 128)))\n",
    "\n",
    "    # upsample to (..., 14, 14)\n",
    "    cnn.add(UpSampling2D(size=(2, 2)))\n",
    "    cnn.add(Conv2D(256, 5, padding='same', activation='relu',\n",
    "                   kernel_initializer='glorot_normal'))\n",
    "\n",
    "    # upsample to (..., 28, 28)\n",
    "    cnn.add(UpSampling2D(size=(2, 2)))\n",
    "    cnn.add(Conv2D(128, 5, padding='same', activation='relu',\n",
    "                   kernel_initializer='glorot_normal'))\n",
    "\n",
    "    # take a channel axis reduction\n",
    "    cnn.add(Conv2D(1, 2, padding='same', activation='tanh',\n",
    "                   kernel_initializer='glorot_normal'))\n",
    "\n",
    "    # this is the z space commonly refered to in GAN papers\n",
    "    latent = Input(shape=(latent_size,))\n",
    "\n",
    "    # this will be our label\n",
    "    image_class = Input(shape=(1,), dtype='int32')\n",
    "\n",
    "    # 10 classes in MNIST\n",
    "    emb = Embedding(10, latent_size, embeddings_initializer='glorot_normal')(image_class)\n",
    "    cls = Flatten()(emb)\n",
    "\n",
    "    # hadamard product between z-space and a class conditional embedding\n",
    "    h = Multiply()([latent, cls])\n",
    "\n",
    "    fake_image = cnn(h)\n",
    "\n",
    "    return Model([latent, image_class], fake_image)\n",
    "\n",
    "\n",
    "def build_discriminator():\n",
    "    # build a relatively standard conv net, with LeakyReLUs as suggested in\n",
    "    # the reference paper\n",
    "    cnn = Sequential()\n",
    "\n",
    "    cnn.add(Conv2D(32, 3, padding='same', strides=2, input_shape=(28, 28, 1)))\n",
    "    cnn.add(LeakyReLU())\n",
    "    cnn.add(Dropout(0.3))\n",
    "\n",
    "    cnn.add(Conv2D(64, 3, padding='same', strides=1))\n",
    "    cnn.add(LeakyReLU())\n",
    "    cnn.add(Dropout(0.3))\n",
    "\n",
    "    cnn.add(Conv2D(128, 3, padding='same', strides=2))\n",
    "    cnn.add(LeakyReLU())\n",
    "    cnn.add(Dropout(0.3))\n",
    "\n",
    "    cnn.add(Conv2D(256, 3, padding='same', strides=1))\n",
    "    cnn.add(LeakyReLU())\n",
    "    cnn.add(Dropout(0.3))\n",
    "\n",
    "    cnn.add(Flatten())\n",
    "\n",
    "    image = Input(shape=(28, 28, 1))\n",
    "\n",
    "    features = cnn(image)\n",
    "\n",
    "    # first output (name=generation) is whether or not the discriminator\n",
    "    # thinks the image that is being shown is fake, and the second output\n",
    "    # (name=auxiliary) is the class that the discriminator thinks the image\n",
    "    # belongs to.\n",
    "    fake = Dense(1, activation='sigmoid', name='generation')(features)\n",
    "    aux = Dense(10, activation='softmax', name='auxiliary')(features)\n",
    "\n",
    "    return Model(image, [fake, aux])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# batch and latent size taken from the paper\n",
    "epochs = 50\n",
    "batch_size = 100\n",
    "latent_size = 100\n",
    "\n",
    "# Adam parameters suggested in https://arxiv.org/abs/1511.06434\n",
    "# decreased learning rate from repo settings\n",
    "adam_lr = 0.00005\n",
    "adam_beta_1 = 0.5\n",
    "\n",
    "# build the discriminator\n",
    "discriminator = build_discriminator()\n",
    "discriminator.compile(\n",
    "    optimizer=Adam(lr=adam_lr, beta_1=adam_beta_1),\n",
    "    loss=['binary_crossentropy', 'sparse_categorical_crossentropy']\n",
    ")\n",
    "\n",
    "# build the generator\n",
    "generator = build_generator(latent_size)\n",
    "generator.compile(\n",
    "    optimizer=Adam(lr=adam_lr, beta_1=adam_beta_1),\n",
    "    loss='binary_crossentropy'\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "__________________________________________________________________________________________________\n",
      "Layer (type)                    Output Shape         Param #     Connected to                     \n",
      "==================================================================================================\n",
      "input_3 (InputLayer)            (None, 1)            0                                            \n",
      "__________________________________________________________________________________________________\n",
      "embedding_1 (Embedding)         (None, 1, 100)       1000        input_3[0][0]                    \n",
      "__________________________________________________________________________________________________\n",
      "input_2 (InputLayer)            (None, 100)          0                                            \n",
      "__________________________________________________________________________________________________\n",
      "flatten_2 (Flatten)             (None, 100)          0           embedding_1[0][0]                \n",
      "__________________________________________________________________________________________________\n",
      "multiply_1 (Multiply)           (None, 100)          0           input_2[0][0]                    \n",
      "                                                                 flatten_2[0][0]                  \n",
      "__________________________________________________________________________________________________\n",
      "sequential_2 (Sequential)       (None, 28, 28, 1)    8171521     multiply_1[0][0]                 \n",
      "==================================================================================================\n",
      "Total params: 8,172,521\n",
      "Trainable params: 8,172,521\n",
      "Non-trainable params: 0\n",
      "__________________________________________________________________________________________________\n"
     ]
    }
   ],
   "source": [
    "generator.summary()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "__________________________________________________________________________________________________\n",
      "Layer (type)                    Output Shape         Param #     Connected to                     \n",
      "==================================================================================================\n",
      "input_1 (InputLayer)            (None, 28, 28, 1)    0                                            \n",
      "__________________________________________________________________________________________________\n",
      "sequential_1 (Sequential)       (None, 12544)        387840      input_1[0][0]                    \n",
      "__________________________________________________________________________________________________\n",
      "generation (Dense)              (None, 1)            12545       sequential_1[1][0]               \n",
      "__________________________________________________________________________________________________\n",
      "auxiliary (Dense)               (None, 10)           125450      sequential_1[1][0]               \n",
      "==================================================================================================\n",
      "Total params: 525,835\n",
      "Trainable params: 525,835\n",
      "Non-trainable params: 0\n",
      "__________________________________________________________________________________________________\n"
     ]
    }
   ],
   "source": [
    "discriminator.summary()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "latent = Input(shape=(latent_size, ))\n",
    "image_class = Input(shape=(1,), dtype='int32')\n",
    "\n",
    "# get a fake image\n",
    "fake = generator([latent, image_class])\n",
    "\n",
    "# we only want to be able to train generation for the combined model\n",
    "discriminator.trainable = False\n",
    "fake, aux = discriminator(fake)\n",
    "combined = Model([latent, image_class], [fake, aux])\n",
    "\n",
    "combined.compile(\n",
    "    optimizer=Adam(lr=adam_lr, beta_1=adam_beta_1),\n",
    "    loss=['binary_crossentropy', 'sparse_categorical_crossentropy']\n",
    ")\n",
    "\n",
    "# get our mnist data, and force it to be of shape (..., 28, 28, 1) with\n",
    "# range [-1, 1]\n",
    "(X_train, y_train), (X_test, y_test) = mnist.load_data()\n",
    "X_train = (X_train.astype(np.float32) - 127.5) / 127.5\n",
    "X_train = np.expand_dims(X_train, axis=3)\n",
    "X_test = (X_test.astype(np.float32) - 127.5) / 127.5\n",
    "X_test = np.expand_dims(X_test, axis=3)\n",
    "\n",
    "num_train, num_test = X_train.shape[0], X_test.shape[0]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch\tL_s(G)\tL_s(G)\tL_s(D)\tL_s(D)\tL_c(G)\tL_c(G)\tL_c(D)\tL_c(D)\n",
      "1\t"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/leon/miniconda3/lib/python3.6/site-packages/keras/engine/training.py:973: UserWarning: Discrepancy between trainable weights and collected trainable weights, did you set `model.trainable` without calling `model.compile` after ?\n",
      "  'Discrepancy between trainable weights and collected trainable'\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1.1100\t0.9312\t0.5795\t0.7355\t2.2971\t1.2854\t1.6925\t0.9315\n",
      "2\t1.0730\t0.5990\t0.6369\t0.8421\t0.5163\t0.1205\t0.5372\t0.2955\n",
      "3\t1.1886\t0.8601\t0.5789\t0.6166\t0.1651\t0.0631\t0.3060\t0.2019\n",
      "4\t1.2439\t0.9160\t0.5571\t0.4781\t0.1322\t0.0678\t0.2684\t0.1844\n",
      "5\t1.2500\t0.7964\t0.5549\t0.6434\t0.1182\t0.0712\t0.2339\t0.1710\n",
      "6\t1.2235\t1.0951\t0.5536\t0.5279\t0.0581\t0.0745\t0.1744\t0.1409\n",
      "7\t1.0737\t1.2428\t0.6150\t0.5226\t0.0483\t0.0046\t0.1400\t0.0813\n",
      "8\t0.9644\t0.9570\t0.6262\t0.5729\t0.0208\t0.0086\t0.1069\t0.0717\n",
      "9\t0.9787\t0.7053\t0.6365\t0.7311\t0.0234\t0.0086\t0.0990\t0.0628\n",
      "10\t0.9373\t0.7018\t0.6392\t0.6134\t0.0175\t0.0028\t0.0875\t0.0541\n",
      "11\t0.8988\t0.8082\t0.6409\t0.5941\t0.0157\t0.0010\t0.0796\t0.0480\n",
      "12\t0.8516\t0.7033\t0.6608\t0.6815\t0.0122\t0.0092\t0.0707\t0.0479\n",
      "13\t0.8081\t0.7507\t0.6802\t0.7611\t0.0116\t0.0074\t0.0681\t0.0461\n",
      "14\t0.7778\t0.7105\t0.6895\t0.6852\t0.0113\t0.0051\t0.0662\t0.0425\n",
      "15\t0.7542\t0.6691\t0.6919\t0.6839\t0.0092\t0.0033\t0.0610\t0.0375\n",
      "16\t0.7228\t0.7059\t0.7044\t0.6954\t0.0078\t0.0032\t0.0566\t0.0359\n",
      "17\t0.7198\t0.6816\t0.7035\t0.7041\t0.0067\t0.0147\t0.0536\t0.0381\n",
      "18\t0.7124\t0.6995\t0.7046\t0.7041\t0.0057\t0.0027\t0.0504\t0.0306\n",
      "19\t0.7120\t0.7010\t0.7029\t0.6979\t0.0047\t0.0021\t0.0475\t0.0296\n",
      "20\t0.7089\t0.7003\t0.7036\t0.6976\t0.0046\t0.0036\t0.0446\t0.0286\n",
      "21\t0.7110\t0.6781\t0.7022\t0.7028\t0.0046\t0.0008\t0.0436\t0.0259\n",
      "22\t0.7097\t0.6950\t0.7020\t0.7011\t0.0040\t0.0040\t0.0412\t0.0267\n",
      "23\t0.7120\t0.6829\t0.7020\t0.7112\t0.0040\t0.0014\t0.0398\t0.0240\n",
      "24\t0.7111\t0.6974\t0.7007\t0.7322\t0.0035\t0.0035\t0.0377\t0.0247\n",
      "25\t0.7090\t0.6597\t0.7012\t0.7381\t0.0032\t0.0048\t0.0360\t0.0244\n",
      "26\t0.7066\t0.6888\t0.7012\t0.7308\t0.0030\t0.0010\t0.0349\t0.0216\n",
      "27\t0.7070\t0.6634\t0.7015\t0.7180\t0.0030\t0.0006\t0.0336\t0.0207\n",
      "28\t0.7052\t0.6593\t0.7008\t0.7249\t0.0028\t0.0014\t0.0326\t0.0204\n",
      "29\t0.7081\t0.6809\t0.7002\t0.7060\t0.0028\t0.0006\t0.0322\t0.0193\n",
      "30\t0.7069\t0.6645\t0.6992\t0.7241\t0.0025\t0.0012\t0.0309\t0.0190\n",
      "31\t0.7059\t0.6731\t0.6996\t0.7154\t0.0026\t0.0003\t0.0299\t0.0184\n",
      "32\t0.7054\t0.7072\t0.6986\t0.7083\t0.0026\t0.0003\t0.0300\t0.0178\n",
      "33\t0.7056\t0.6759\t0.6997\t0.7130\t0.0025\t0.0004\t0.0286\t0.0173\n",
      "34\t0.7046\t0.6843\t0.6990\t0.7119\t0.0023\t0.0002\t0.0277\t0.0173\n",
      "35\t0.7064\t0.7097\t0.6990\t0.7004\t0.0021\t0.0007\t0.0275\t0.0173\n",
      "36\t0.7063\t0.7225\t0.6985\t0.7003\t0.0021\t0.0001\t0.0260\t0.0161\n",
      "37\t0.7047\t0.6791\t0.6985\t0.7123\t0.0020\t0.0002\t0.0262\t0.0162\n",
      "38\t0.7045\t0.6776\t0.6985\t0.7182\t0.0020\t0.0007\t0.0252\t0.0162\n",
      "39\t0.7060\t0.6918\t0.6983\t0.7046\t0.0020\t0.0004\t0.0260\t0.0159\n",
      "40\t0.7042\t0.6993\t0.6978\t0.7035\t0.0019\t0.0003\t0.0248\t0.0152\n",
      "41\t0.7037\t0.6908\t0.6979\t0.7188\t0.0018\t0.0004\t0.0241\t0.0156\n",
      "42\t0.7047\t0.6831\t0.6977\t0.7045\t0.0018\t0.0001\t0.0243\t0.0148\n",
      "43\t0.7035\t0.7004\t0.6978\t0.7097\t0.0017\t0.0002\t0.0234\t0.0149\n",
      "44\t0.7046\t0.6926\t0.6970\t0.7058\t0.0017\t0.0001\t0.0230\t0.0147\n",
      "45\t0.7036\t0.6849\t0.6973\t0.7090\t0.0017\t0.0004\t0.0224\t0.0141\n",
      "46\t0.7045\t0.6993\t0.6976\t0.6973\t0.0017\t0.0002\t0.0221\t0.0141\n",
      "47\t0.7030\t0.6955\t0.6973\t0.7056\t0.0016\t0.0001\t0.0223\t0.0140\n",
      "48\t0.7030\t0.6997\t0.6971\t0.7136\t0.0015\t0.0001\t0.0213\t0.0145\n",
      "49\t0.7030\t0.6788\t0.6971\t0.7100\t0.0015\t0.0001\t0.0213\t0.0139\n",
      "50\t0.7038\t0.6698\t0.6973\t0.6930\t0.0015\t0.0002\t0.0212\t0.0136\n",
      "done.\n"
     ]
    }
   ],
   "source": [
    "print('Epoch\\tL_s(G)\\tL_s(G)\\tL_s(D)\\tL_s(D)\\tL_c(G)\\tL_c(G)\\tL_c(D)\\tL_c(D)')\n",
    "for epoch in range(epochs):\n",
    "    print(epoch + 1, end='\\t', flush=True)\n",
    "\n",
    "    num_batches = int(X_train.shape[0] / batch_size)\n",
    "\n",
    "    epoch_gen_loss = []\n",
    "    epoch_disc_loss = []\n",
    "\n",
    "    for index in range(num_batches):\n",
    "        # generate a new batch of noise\n",
    "        noise = np.random.uniform(-1, 1, (batch_size, latent_size))\n",
    "\n",
    "        # get a batch of real images\n",
    "        image_batch = X_train[index * batch_size:(index + 1) * batch_size]\n",
    "        label_batch = y_train[index * batch_size:(index + 1) * batch_size]\n",
    "\n",
    "        # sample some labels from p_c\n",
    "        sampled_labels = np.random.randint(0, 10, batch_size)\n",
    "\n",
    "        # generate a batch of fake images, using the generated labels as a\n",
    "        # conditioner. We reshape the sampled labels to be\n",
    "        # (batch_size, 1) so that we can feed them into the embedding\n",
    "        # layer as a length one sequence\n",
    "        generated_images = generator.predict(\n",
    "            [noise, sampled_labels.reshape((-1, 1))], verbose=0)\n",
    "\n",
    "        X = np.concatenate((image_batch, generated_images))\n",
    "        y = np.array([1] * batch_size + [0] * batch_size)\n",
    "        aux_y = np.concatenate((label_batch, sampled_labels), axis=0)\n",
    "\n",
    "        # see if the discriminator can figure itself out...\n",
    "        epoch_disc_loss.append(discriminator.train_on_batch(X, [y, aux_y]))\n",
    "\n",
    "        # make new noise. we generate 2 * batch size here such that we have\n",
    "        # the generator optimize over an identical number of images as the\n",
    "        # discriminator\n",
    "        noise = np.random.uniform(-1, 1, (2 * batch_size, latent_size))\n",
    "        sampled_labels = np.random.randint(0, 10, 2 * batch_size)\n",
    "\n",
    "        # we want to train the generator to trick the discriminator\n",
    "        # For the generator, we want all the {fake, not-fake} labels to say\n",
    "        # not-fake\n",
    "        trick = np.ones(2 * batch_size)\n",
    "\n",
    "        epoch_gen_loss.append(\n",
    "            combined.train_on_batch(\n",
    "                [noise, sampled_labels.reshape((-1, 1))],\n",
    "                [trick, sampled_labels]\n",
    "            )\n",
    "        )\n",
    "\n",
    "    # evaluate the testing loss here\n",
    "\n",
    "    # generate a new batch of noise\n",
    "    noise = np.random.uniform(-1, 1, (num_test, latent_size))\n",
    "\n",
    "    # sample some labels from p_c and generate images from them\n",
    "    sampled_labels = np.random.randint(0, 10, num_test)\n",
    "    generated_images = generator.predict(\n",
    "        [noise, sampled_labels.reshape((-1, 1))], verbose=0)\n",
    "\n",
    "    X = np.concatenate((X_test, generated_images))\n",
    "    y = np.array([1] * num_test + [0] * num_test)\n",
    "    aux_y = np.concatenate((y_test, sampled_labels), axis=0)\n",
    "\n",
    "    # see if the discriminator can figure itself out...\n",
    "    discriminator_test_loss = discriminator.evaluate(X, [y, aux_y], verbose=0)\n",
    "\n",
    "    discriminator_train_loss = np.mean(np.array(epoch_disc_loss), axis=0)\n",
    "\n",
    "    # make new noise\n",
    "    noise = np.random.uniform(-1, 1, (2 * num_test, latent_size))\n",
    "    sampled_labels = np.random.randint(0, 10, 2 * num_test)\n",
    "\n",
    "    trick = np.ones(2 * num_test)\n",
    "\n",
    "    generator_test_loss = combined.evaluate(\n",
    "        [noise, sampled_labels.reshape((-1, 1))],\n",
    "        [trick, sampled_labels],\n",
    "        verbose=0\n",
    "    )\n",
    "\n",
    "    generator_train_loss = np.mean(np.array(epoch_gen_loss), axis=0)\n",
    "    \n",
    "    print('{:.4f}\\t{:.4f}\\t{:.4f}\\t{:.4f}\\t{:.4f}\\t{:.4f}\\t{:.4f}\\t{:.4f}'.format(\n",
    "        # generation loss\n",
    "        generator_train_loss[1], generator_test_loss[1],\n",
    "        discriminator_train_loss[1], discriminator_test_loss[1],\n",
    "        # auxillary loss\n",
    "        generator_train_loss[2], generator_test_loss[2],\n",
    "        discriminator_train_loss[2], discriminator_test_loss[2],\n",
    "    ))\n",
    "\n",
    "    # save model every epoch\n",
    "    generator.save(KERAS_MODEL_FILEPATH)\n",
    "\n",
    "#     # generate some digits to display\n",
    "#     noise = np.random.uniform(-1, 1, (100, latent_size))\n",
    "\n",
    "#     sampled_labels = np.array([[i] * 10 for i in range(10)]).reshape(-1, 1)\n",
    "\n",
    "#     # get a batch to display\n",
    "#     generated_images = generator.predict([noise, sampled_labels], verbose=0)\n",
    "\n",
    "#     # arrange them into a grid\n",
    "#     img = (np.concatenate([r.reshape(-1, 28)\n",
    "#                            for r in np.split(generated_images, 10)\n",
    "#                            ], axis=-1) * 127.5 + 127.5).astype(np.uint8)\n",
    "\n",
    "#     Image.fromarray(img).save('../../demos/data/mnist_acgan/mnist_acgan_generated_{0:03d}.png'.format(epoch))\n",
    "    \n",
    "print('done.')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### generate"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "def make_digit(digit=None):\n",
    "    noise = np.random.uniform(-1, 1, (1, latent_size))\n",
    "\n",
    "    sampled_label = np.array([\n",
    "            digit if digit is not None else np.random.randint(0, 10, 1)\n",
    "        ]).reshape(-1, 1)\n",
    "\n",
    "    generated_image = generator.predict(\n",
    "        [noise, sampled_label], verbose=0)\n",
    "\n",
    "    return np.squeeze((generated_image * 127.5 + 127.5).astype(np.uint8))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(-0.5, 27.5, 27.5, -0.5)"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAP8AAAD8CAYAAAC4nHJkAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAB0RJREFUeJzt3U+IT/sfx/EZQwZRYuFvFsjCxIaIDQtbSzYSKWWDrCh2\nk72UlaUIERsWFlL+bOTfioWICYlRI8mf4a5+v915H9eXL3dej8f2dY/vuPc+O4vPnPPt/f79ew+Q\nZ9yf/gGAP0P8EEr8EEr8EEr8EEr8EEr8EEr8EEr8EGp8lz/PrxPC79f7I/+QOz+EEj+EEj+EEj+E\nEj+EEj+EEj+EEj+EEj+EEj+EEj+EEj+EEj+EEj+EEj+E6vbz/PDDBgYGyv3QoUPlvnnz5l/544w5\n7vwQSvwQSvwQSvwQSvwQSvwQqvf7966+Tduru/m/z58/l/vEiRPLferUqeU+MjLyr3+mMcKru4Fm\n4odQ4odQ4odQ4odQ4odQ4odQHunlt6p+j2T58uUd/dnTp0/v6Pp07vwQSvwQSvwQSvwQSvwQSvwQ\nSvwQyjk/HWl7H8T27dsbt4cPH3b02cePH+/o+nTu/BBK/BBK/BBK/BBK/BBK/BBK/BDKe/spffny\npdw3bdpU7hcuXPjpz540aVK5Dw8Pl3t/f/9Pf/Z/nPf2A83ED6HED6HED6HED6HED6E80huu7Shv\n69at5d7JUd64cfW959y5c+UefJT3S7jzQyjxQyjxQyjxQyjxQyjxQyjxQyjn/GPc6OhouR85cqTc\nT58+3dHn9/Y2P116+/bt8tply5Z19NnU3PkhlPghlPghlPghlPghlPghlPghlHP+Me7Vq1flfvDg\nwXLv9NXu+/fvb9wGBgbKa6vfEejpaf/Z2q5P584PocQPocQPocQPocQPocQPocQPoZzzjwEfP35s\n3Hbs2FFe++nTp3JvOysfHBws9927dzdube/t//DhQ7l//fq13KdNm9a49fX1ldcmcOeHUOKHUOKH\nUOKHUOKHUOKHUOKHUM75/wJt79YfHh4u96tXrzZu165dK6+dMGFCuV+8eLHcFy9eXO5Pnz5t3EZG\nRspr79y5U+5DQ0PlvmXLlsZt6dKl5bUJ7wJw54dQ4odQ4odQ4odQ4odQ4odQjvr+Aq9fvy73PXv2\nlPulS5cat7ZHdtuOvM6fP1/uly9fLvc3b940btUjtz09PT0LFy4s9xcvXpT7kydPGrdjx46V186Y\nMaPcxwJ3fgglfgglfgglfgglfgglfgglfgjV2+lXMP9LXf2w/4r79++X+6pVq8q9OstfsmTJT/1M\n//Po0aOOru/v72/crl+/Xl47f/78ct+1a1e537x5s3E7fPhwee22bdvK/S9/5PeHfjh3fgglfggl\nfgglfgglfgglfgglfgjlef4uaHs1986dO8u97Zn86iz9zJkz5bUrV64s9zZtX7P94MGDxm3u3Lnl\ntY8fPy736hy/p6d+LfnatWvLa//yc/xfwp0fQokfQokfQokfQokfQokfQokfQjnn74KvX7+W+717\n9zr681++fNm4TZkypbz28+fP5d72bv23b9+We/V7AB8+fCiv3bhxY7m/evWq3BcsWNC4tf17SeDO\nD6HED6HED6HED6HED6HED6HED6Gc83dB27vv287a2549r87iq+fpf8Tz58/Lffz4+n+h6nshRkZG\nymufPXtW7m3P3J84caJxmzNnTnltAnd+CCV+CCV+CCV+CCV+CCV+COWorwvu3r3b0fWDg4Pl/uXL\nl8Zt8uTJ5bWzZs0q9+r11z097Y8rr1+/vnG7ceNGeW3b18fPmzev3FevXt24Jbyau407P4QSP4QS\nP4QSP4QSP4QSP4QSP4TqbTtL/cW6+mF/izVr1pT7rVu3yv3kyZPlvmHDhsZt6tSp5bVHjx4t9+Hh\n4XI/depUuT99+rTcK22vDR8aGir3tr/7GPZDv8Tgzg+hxA+hxA+hxA+hxA+hxA+hxA+hnPP/AqOj\no+U+ffr0cn///n259/f3l3v1TP6BAwfKa/ft21fubV+j3Yl169aV+5UrV8q97bXhwZzzA83ED6HE\nD6HED6HED6HED6HED6Gc83fB3r17y/3IkSO/7bPb3k//u//7L1u2rHG7fv16eW3w8/idcs4PNBM/\nhBI/hBI/hBI/hBI/hHLU1wVtj/zOnj273N+9e1fubV+TXZk5c2a5L1q0qNzPnj1b7nPnzm3cfE32\nb+OoD2gmfgglfgglfgglfgglfgglfgjlnP8v8O3bt3Jv+yrqt2/fNm4fP34sr12xYkW59/X1dbTz\nRzjnB5qJH0KJH0KJH0KJH0KJH0KJH0I554exxzk/0Ez8EEr8EEr8EEr8EEr8EEr8EEr8EEr8EEr8\nEEr8EEr8EEr8EEr8EEr8EEr8EEr8EEr8EEr8EEr8EEr8EEr8EEr8EEr8EEr8EEr8EEr8EEr8EEr8\nEEr8EEr8EEr8EEr8EEr8EEr8EEr8EEr8EEr8EGp8lz+vt8ufBzRw54dQ4odQ4odQ4odQ4odQ4odQ\n4odQ4odQ4odQ4odQ4odQ4odQ4odQ4odQ4odQ4odQ4odQ4odQ4odQ4odQ4odQ4odQ4odQ/wACq0pZ\nA+Nj6gAAAABJRU5ErkJggg==\n",
      "text/plain": [
       "<matplotlib.figure.Figure at 0x7f981c1d5eb8>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "plt.imshow(make_digit(digit=6), cmap='gray_r', interpolation='nearest')\n",
    "plt.axis('off')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
